{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n",
    "# os.environ[\"HTTP_PROXY\"] = \"http://127.0.0.1:7890\"\n",
    "# os.environ[\"HTTPS_PROXY\"] = \"http://127.0.0.1:7890\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据集加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the train/test split is identical after a kernel restart.\n",
    "SPLIT_SEED = 42\n",
    "\n",
    "data_files = {\"train\": \"./data/train.csv\"}\n",
    "\n",
    "# Only one CSV is loaded; 20% of it is held out as the test split.\n",
    "dataset = load_dataset(\"csv\", data_files=data_files, num_proc=8)\n",
    "train_and_test = dataset[\"train\"].train_test_split(test_size=0.2, seed=SPLIT_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'],\n",
      "        num_rows: 127656\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['id', 'comment_text', 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'],\n",
      "        num_rows: 31915\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "print(train_and_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names listed in the order of their label ids (0-5).\n",
    "label2id = {\n",
    "    name: index\n",
    "    for index, name in enumerate(\n",
    "        (\"toxic\", \"severe_toxic\", \"obscene\", \"threat\", \"insult\", \"identity_hate\")\n",
    "    )\n",
    "}\n",
    "\n",
    "# Inverse mapping, derived from label2id so the two cannot drift apart.\n",
    "id2label = {index: name for name, index in label2id.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenized_fn(example):\n",
    "    \"\"\"Tokenize the ``comment_text`` field into fixed-length model inputs.\n",
    "\n",
    "    Texts are truncated and padded to ``max_length`` (64), and the tokenizer\n",
    "    output is returned unchanged.\n",
    "    \"\"\"\n",
    "    tokenizer_kwargs = {\"truncation\": True, \"padding\": \"max_length\", \"max_length\": 64}\n",
    "    return tokenizer(example[\"comment_text\"], **tokenizer_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_labels(label1, label2, label3, label4, label5, label6):\n",
    "    \"\"\"Pack six parallel label columns into one float vector per example.\n",
    "\n",
    "    Each argument is an indexable column of the batch; position ``i`` of every\n",
    "    column is cast to ``float`` and gathered into row ``i``. The number of rows\n",
    "    is taken from ``label1``.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``{\"labels\": rows}`` where each row holds six floats.\n",
    "    \"\"\"\n",
    "    columns = (label1, label2, label3, label4, label5, label6)\n",
    "    rows = [[float(column[i]) for column in columns] for i in range(len(label1))]\n",
    "    return {\"labels\": rows}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def datapipe(dataset):\n",
    "    \"\"\"Tokenize the comments, then pack the six label columns into ``labels``.\n",
    "\n",
    "    Both steps run as batched ``map`` calls; the mapped dataset is returned.\n",
    "    \"\"\"\n",
    "    label_columns = [\n",
    "        \"toxic\",\n",
    "        \"severe_toxic\",\n",
    "        \"obscene\",\n",
    "        \"threat\",\n",
    "        \"insult\",\n",
    "        \"identity_hate\",\n",
    "    ]\n",
    "    tokenized = dataset.map(tokenized_fn, batched=True)\n",
    "    return tokenized.map(gen_labels, input_columns=label_columns, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 127656/127656 [00:19<00:00, 6666.37 examples/s]\n",
      "Map: 100%|██████████| 31915/31915 [00:04<00:00, 6698.83 examples/s]\n",
      "Map: 100%|██████████| 127656/127656 [00:00<00:00, 163420.93 examples/s]\n",
      "Map: 100%|██████████| 31915/31915 [00:00<00:00, 213154.46 examples/s]\n"
     ]
    }
   ],
   "source": [
    "train_and_test = datapipe(train_and_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_and_test = train_and_test.select_columns(\n",
    "    [\n",
    "        \"input_ids\",\n",
    "        \"attention_mask\",\n",
    "        \"labels\",\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 127656\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 31915\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_and_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [101, 1045, 2134, 1005, 1056, 5382, 1998, 1045, 2572, 2047, 2182, 2074, 2893, 1996, 6865, 1997, 2009, 1012, 4283, 2005, 1996, 4641, 2039, 2295, 999, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}\n"
     ]
    }
   ],
   "source": [
    "print(train_and_test[\"train\"][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSequenceClassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Pretrained BERT encoder with a classification head sized to the label set,\n",
    "# configured for multi-label classification.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"google-bert/bert-base-uncased\",\n",
    "    problem_type=\"multi_label_classification\",\n",
    "    label2id=label2id,\n",
    "    id2label=id2label,\n",
    "    num_labels=len(id2label),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the pretrained weights and train only the classification head.\n",
    "# Assigning `module.requires_grad = ...` merely sets a Python attribute on the\n",
    "# nn.Module and leaves every parameter trainable; `requires_grad_()` is the call\n",
    "# that actually updates the flag on each parameter.\n",
    "# NOTE(review): with the encoder frozen, the learning rate used for full\n",
    "# fine-tuning is presumably too small for the head alone - confirm.\n",
    "model.requires_grad_(False)\n",
    "model.classifier.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评价指标"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    \"\"\"Element-wise logistic function ``1 / (1 + exp(-x))``.\n",
    "\n",
    "    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that\n",
    "    large negative inputs do not overflow ``exp`` (the naive form emits a\n",
    "    RuntimeWarning there).\n",
    "\n",
    "    Args:\n",
    "        x: NumPy array or scalar of logits.\n",
    "\n",
    "    Returns:\n",
    "        Probabilities in [0, 1] with the same shape as ``x``.\n",
    "    \"\"\"\n",
    "    return np.exp(-np.logaddexp(0, -x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute label-wise accuracy and per-label macro F1 for multi-label output.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: ``(logits, labels)`` pair; both are assumed to be arrays of\n",
    "            shape ``(n_samples, n_labels)`` - TODO confirm against the Trainer.\n",
    "\n",
    "    Returns:\n",
    "        dict: ``\"accuracy\"`` is the fraction of individual (sample, label)\n",
    "        decisions that are correct; ``\"f1\"`` is the F1 score computed per label\n",
    "        column and then averaged.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    preds = (sigmoid(logits) > 0.5).astype(int)\n",
    "    labels = labels.astype(int)\n",
    "    # Accuracy is taken over every (sample, label) decision, so flatten both.\n",
    "    accuracy = accuracy_score(labels.reshape(-1), preds.reshape(-1))\n",
    "    # Keep the 2-D indicator layout for F1: flattening first would make \"macro\"\n",
    "    # average the F1 of the 0-class and the 1-class instead of the labels.\n",
    "    f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
    "    return {\n",
    "        \"accuracy\": accuracy,\n",
    "        \"f1\": f1,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-06-13 15:41:45.986988: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2024-06-13 15:41:46.030624: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-06-13 15:41:46.030653: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-06-13 15:41:46.030684: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-06-13 15:41:46.038994: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-06-13 15:41:46.829561: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./output/\",\n",
    "    # Optimisation settings.\n",
    "    num_train_epochs=5,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    per_device_train_batch_size=128,\n",
    "    per_device_eval_batch_size=128,\n",
    "    # Evaluate and checkpoint once per epoch, then reload the best checkpoint.\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [101, 1045, 2123, 1005, 1056, 2113, 2040, 2023, 10812, 3559, 2003, 1010, 2021, 1045, 2123, 1005, 1056, 2903, 1996, 1996, 5464, 2003, 2025, 14485, 1012, 2009, 2003, 2899, 7058, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}\n"
     ]
    }
   ],
   "source": [
    "print(train_and_test[\"test\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    train_dataset=train_and_test[\"train\"],\n",
    "    eval_dataset=train_and_test[\"test\"],\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='500' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [250/250 22:28]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.05397845804691315,\n",
       " 'eval_accuracy': 0.9817588385816491,\n",
       " 'eval_f1': 0.8570132571465302,\n",
       " 'eval_runtime': 100.4745,\n",
       " 'eval_samples_per_second': 317.643,\n",
       " 'eval_steps_per_second': 2.488}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4990' max='4990' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4990/4990 1:50:45, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.046600</td>\n",
       "      <td>0.042239</td>\n",
       "      <td>0.984312</td>\n",
       "      <td>0.882476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.036300</td>\n",
       "      <td>0.041644</td>\n",
       "      <td>0.984067</td>\n",
       "      <td>0.881441</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.029900</td>\n",
       "      <td>0.043375</td>\n",
       "      <td>0.983816</td>\n",
       "      <td>0.883072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.024800</td>\n",
       "      <td>0.048329</td>\n",
       "      <td>0.983503</td>\n",
       "      <td>0.882226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.021100</td>\n",
       "      <td>0.049663</td>\n",
       "      <td>0.983153</td>\n",
       "      <td>0.881117</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=4990, training_loss=0.03134886850574929, metrics={'train_runtime': 6646.9584, 'train_samples_per_second': 96.026, 'train_steps_per_second': 0.751, 'total_flos': 2.099306947802112e+16, 'train_loss': 0.03134886850574929, 'epoch': 5.0})"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "answer = (sigmoid(trainer.predict(train_and_test[\"test\"]).predictions) > 0.5).astype(\n",
    "    int\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 1, 0, 1, 0])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "answer[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1.0, 0.0, 1.0, 0.0, 1.0, 1.0]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index the row first so only one example is read, not the whole labels column.\n",
    "train_and_test[\"test\"][2][\"labels\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
